# Define functions -------------------------------------------------------------
#   This file is sourced by main.R after the environment is set up. All
#   functions used to clean, process, analyze, and visualize the data live
#   here, except for inline (anonymous) functions.

# Automatic Data Analysis Function ---------------------------------------------

# Run the full analysis pipeline for one experiment.
#
# Publishes the inputs as the globals `df`, `experiment`, and
# `path_experiment` (helpers in this file read `df` and `experiment`; the
# sourced scripts presumably rely on all three), then sources the processing,
# modelling, plotting, and table scripts in order. source() evaluates each
# script in the global environment.
#
#   data: the experiment's data frame.
#   name: experiment identifier (e.g. "1", "2", "3"); also names the plot
#         directory "plots/<name>".
# Returns NULL invisibly; called for its side effects.
analyze_data <- function(data, name) {
  message("Starting analysis for Experiment: ", name)

  shared <- list(
    df = data,
    experiment = name,
    path_experiment = paste0("plots/", name)
  )
  for (var_name in names(shared)) {
    assign(var_name, shared[[var_name]], envir = .GlobalEnv)
  }

  # Progress message -> scripts executed under it, in pipeline order.
  pipeline <- list(
    "Processing data" = c("code/variables.R", "code/mapping.R"),
    "Fitting models" = "code/models.R",
    "Exporting plots" = "code/plots.R",
    "Exporting tables" = "code/tables.R"
  )
  for (stage in names(pipeline)) {
    message(stage)
    for (script in pipeline[[stage]]) {
      source(script)
    }
  }

  message("Data analysis complete!\n")
}

# Data Processing Functions ----------------------------------------------------

# Convert a 0-6 confidence rating into an ordered, categorical, or dummy factor.
#
#   var: confidence codes 0-6 (0 = "Completely doubtful", 3 = neutral,
#        6 = "Completely confident"). In "ordered" mode, entries containing
#        "Neither" are treated as the neutral code 3.
#   to:  "ordered" (default) -> 7-level ordered factor;
#        "categorical"       -> Doubtful (0-2) / Neither / Confident (4-6);
#        "dummy"             -> Doubtful (0-3) / Confident (anything else).
# Returns the converted factor.
convert_confidence <- function(var, to = c("ordered",
                                           "categorical",
                                           "dummy")) {
  # `to` defaults to the whole choice vector; using that directly in `if`
  # errors in R >= 4.2, so resolve it to one validated choice first.
  to <- match.arg(to)

  confident <- c("Completely confident", "Confident", "Somewhat confident")
  doubtful <- c("Completely doubtful", "Doubtful", "Somewhat doubtful")

  new_var <- switch(
    to,
    ordered = {
      # Map "Neither ..." labels onto the neutral code so they land on the
      # "Neither" level (a literal "Neither" string matches none of the 0-6
      # levels and would become NA).
      var <- ifelse(grepl("Neither", var), 3, var)
      factor(var,
             levels = 0:6,
             labels = c(doubtful, "Neither", rev(confident)),
             ordered = TRUE)
    },
    # NOTE(review): values outside 0-6, including NA, fall into "Neither".
    categorical = ifelse(var %in% 0:2, "Doubtful",
                         ifelse(var %in% 4:6, "Confident", "Neither")) |>
      factor(levels = c("Doubtful", "Neither", "Confident")) |>
      stats::relevel(ref = "Doubtful"),
    # NOTE(review): anything not in 0-3, including NA, becomes "Confident".
    dummy = ifelse(var %in% 0:3, 0, 1) |>
      factor(levels = c(0, 1), labels = c("Doubtful", "Confident")) |>
      stats::relevel(ref = "Doubtful")
  )
  new_var
}

# Create dummy factors for foreign policy orientation scores (ci, mi, iso).
#
#   x:    numeric scores; NA entries stay NA in the result.
#   type: "default" -> "Yes" above the median, "No" otherwise;
#         "extreme" -> "Low" at/below the 25th percentile, "High" at/above the
#                      75th percentile, NA in between.
#         Any other value is an error.
# Returns a two-level factor.
set_dummy <- function(x, type = "default") {
  type <- match.arg(type, c("default", "extreme"))
  # na.rm = TRUE: without it quantile() stops on any missing score.
  q <- stats::quantile(x, probs = c(.25, .5, .75), na.rm = TRUE)

  if (type == "default") {
    ifelse(x > q[["50%"]], 1, 0) |>
      factor(levels = c(0, 1), labels = c("No", "Yes"))
  } else {
    ifelse(x <= q[["25%"]], 0, ifelse(x >= q[["75%"]], 1, NA)) |>
      factor(levels = c(0, 1), labels = c("Low", "High"))
  }
}

# Recode a 0-6 feeling rating toward a country into a factor.
#
#   country: ratings 0-6 (0 = "Very Cold", 3 = neutral, 6 = "Very warm").
#   type:    "all" (default) -> 7 levels, reference "Neither cold nor warm";
#            "main"    -> Cold (0-2) / Neither (3) / Warm (other), NA kept,
#                         reference "Neither";
#            "imputed" -> as "main", but NA is counted as "Cold";
#            "rival"   -> Cold (0-2 or NA) / Warm (other), reference "Warm";
#            "ally"    -> Warm (3+ or NA) / Cold (other), reference "Cold".
# Returns the recoded factor.
rate_feelings <- function(country, type = c("all", "main", "imputed",
                                            "rival", "ally")) {
  # The default is the whole choice vector; `if` on it errors in R >= 4.2.
  type <- match.arg(type)

  three_way <- c("Cold", "Neither", "Warm")
  two_way <- c("Cold", "Warm")

  feeling <- switch(
    type,
    all = factor(country,
                 levels = 0:6,
                 labels = c("Very Cold",
                            "Cold",
                            "Somewhat Cold",
                            "Neither cold nor warm",
                            "Somewhat warm",
                            "Warm",
                            "Very warm")) |>
      stats::relevel(ref = "Neither cold nor warm"),
    main = ifelse(country <= 2, 0, ifelse(country == 3, 1, 2)) |>
      factor(levels = 0:2, labels = three_way) |>
      stats::relevel(ref = "Neither"),
    imputed = ifelse(country <= 2 | is.na(country), 0,
                     ifelse(country == 3, 1, 2)) |>
      factor(levels = 0:2, labels = three_way) |>
      stats::relevel(ref = "Neither"),
    rival = ifelse(country <= 2 | is.na(country), 0, 1) |>
      factor(levels = 0:1, labels = two_way) |>
      stats::relevel(ref = "Warm"),
    ally = ifelse(country >= 3 | is.na(country), 1, 0) |>
      factor(levels = 0:1, labels = two_way) |>
      stats::relevel(ref = "Cold")
  )
  feeling
}

# Model Estimation Functions ---------------------------------------------------

# Fit one linear, logistic, ordinal, or multinomial model of `y` on `x`.
#
#   x:    right-hand-side term(s), passed to reformulate().
#   y:    outcome stem; the response used depends on `reg`:
#         linear -> as.numeric(<y>_bin), logistic -> <y>_bin,
#         ordinal -> <y>_ord, multinomial -> <y>_cat.
#   reg:  "linear" (default, lm), "logistic" (glm, logit link),
#         "ordinal" (polr, Hess = TRUE), or "multinomial" (multinom,
#         trace = FALSE). Any other value is an error.
#   data: data frame; defaults to the global `df`.
# Returns the fitted model object.
fit_models <- function(x, y, reg = c("linear", "logistic", "ordinal",
                                     "multinomial"), data = df) {
  # Resolve the default choice vector to a single value (a length-4 `if`
  # condition is an error in R >= 4.2) and reject unknown model types.
  reg <- match.arg(reg)

  response <- switch(
    reg,
    linear = paste0("as.numeric(", y, "_bin)"),
    logistic = paste0(y, "_bin"),
    ordinal = paste0(y, "_ord"),
    multinomial = paste0(y, "_cat")
  )
  fml <- reformulate(x, response)

  switch(
    reg,
    linear = lm(fml, data = data),
    logistic = glm(fml, data = data, family = binomial("logit")),
    ordinal = polr(fml, data = data, Hess = TRUE),
    multinomial = multinom(fml, data = data, trace = FALSE)
  )
}

# Fit linear, logistic, ordinal, or multinomial models in bulk.
#
# For each element of `iv`, fit_models() is called via mapply(), which
# iterates that element, `y`, and `reg` in parallel (shorter ones recycled).
# NOTE(review): a multi-term character element of `iv` is therefore split
# term-by-term; elements are presumably single RHS strings — confirm.
#
#   iv:  predictor specifications, one model set per element.
#   y:   outcome stems (see fit_models()).
#   reg: regression types (see fit_models()).
# Returns a list with one element per `iv` entry, each a list of fitted models.
fit_bulk <- function(iv, y, reg) {
  # FALSE rather than F: `F` is an ordinary, reassignable binding.
  lapply(iv, \(x) mapply(fit_models, x, y, reg, SIMPLIFY = FALSE))
}

# Fit logistic and linear models for analyzing retaliatory strategies.
#
#   x:    right-hand-side specifications; one model of each kind is fitted
#         per element.
#   y:    outcome column name (used as-is for the logit, wrapped in
#         as.numeric() for OLS).
#   data: data frame; defaults to the global `df`.
# Returns list(Logit = <glm fits>, OLS = <lm fits>), each with one element per
# entry of `x`.
fit_strategy <- function(x, y, data = df) {
  logit_fits <- lapply(x, \(rhs) {
    glm(reformulate(rhs, y), data, family = binomial("logit"))
  })

  y_numeric <- paste0("as.numeric(", y, ")")
  ols_fits <- lapply(x, \(rhs) lm(reformulate(rhs, y_numeric), data))

  # Return the list directly; ending on an assignment would return it
  # invisibly.
  list("Logit" = logit_fits, "OLS" = ols_fits)
}

# Estimate conditional average treatment effects (CATE) with a linear model.
#
# The response is as.numeric(<y>_bin). The right-hand side is the global
# `var_full` with interactions "<var_base[[i]]>:<group>" (i = 1, 2; global
# `var_base`) inserted after its second term.
#
#   y:     outcome stem (the `<y>_bin` column is used).
#   group: moderator variable name.
#   data:  data frame; defaults to the global `df`.
# Returns the fitted `lm` object.
fit_cate <- function(y, group, data = df) {
  interactions <- unlist(lapply(1:2, \(i) paste(var_base[[i]], group, sep = ":")))
  rhs_terms <- append(var_full, interactions, after = 2)
  response <- paste0("as.numeric(", y, "_bin)")

  lm(reformulate(rhs_terms, response), data = data)
}

# Marginal Contrast Computation ------------------------------------------------

# Get marginal contrasts for every treatment of the current experiment.
#
# The treatments (variable -> c(reference, contrast)) are chosen by the global
# `experiment` ("1", "2", or "3"; anything else is an error). Each treatment
# is passed to set_diff(), and the results are row-bound with default row
# names.
#
#   model: fitted model handed to set_diff().
#   mod:   moderator variable name handed to set_diff().
# Returns a data frame.
get_diff <- function(model, mod) {
  # One-element named list: treatment variable -> c(reference, contrast).
  contrast_spec <- function(variable, reference, contrast) {
    stats::setNames(list(c(reference, contrast)), variable)
  }

  treatments <- switch(
    experiment,
    "1" = list(
      Evidence = contrast_spec("disavowal", "None", "Evidence"),
      Hypocrisy = contrast_spec("disavowal", "None", "Hypocrisy"),
      Endorsement = contrast_spec("endorsement", "Not Endorsed", "Endorsed")
    ),
    "2" = list(
      Evidence = contrast_spec("disavowal", "None", "Evidence"),
      Hypocrisy = contrast_spec("disavowal", "None", "Hypocrisy"),
      Endorsement = contrast_spec("endorsement", "Partisan", "Bipartisan")
    ),
    "3" = list(
      EU = contrast_spec("endorsement", "None", "EU"),
      US = contrast_spec("endorsement", "None", "US"),
      Israel = contrast_spec("accusation", "Iran", "Israel")
    ),
    stop("Invalid experiment. Valid options: 1, 2, 3")
  )

  pieces <- lapply(treatments, set_diff, model = model, mod = mod)
  do.call(rbind.data.frame, c(pieces, make.row.names = FALSE))
}

# Estimate pairwise differences of one treatment contrast across the levels
# of a moderator, returned as plot-ready coefficient data.
#
#   model: fitted model.
#   trt:   one-element named list, treatment variable -> c(reference,
#          contrast), as built in get_diff().
#   mod:   moderator variable name, used as the `by` grouping.
# Comparisons use HC3 standard errors and the `s.value` column is dropped.
# Returns the ggcoef_model() data with `variable`, `var_label`,
# `reference_row`, and `contrasts` set, then relabelled by set_diff_label().
set_diff <- function(model, trt, mod) {
  comparisons <- avg_comparisons(model, variables = trt, by = mod,
                                 hypothesis = "pairwise", vcov = "HC3")
  comparisons <- subset(comparisons, select = -c(s.value))

  suppressMessages({
    coef_data <- ggcoef_model(comparisons,
                              tidy_fun = tidy_parameters,
                              show_p_values = FALSE,
                              signif_stars = FALSE,
                              return_data = TRUE)
    coef_data <- within(coef_data, {
      variable <- mod
      var_label <- set_diff_var(mod)
      reference_row <- FALSE
      contrasts <- trt[[1]][[2]]
    })
    set_diff_label(coef_data)
  })
}

# Define `var_label` (the facet label) for marginal contrasts.
#
# Foreign policy orientation moderators and pre-existing image moderators get
# a shared group label; any other moderator name is converted to title case.
#
#   x: moderator variable name (a single string).
# Returns a single string.
set_diff_var <- function(x) {
  orientation_vars <- c("ci_high", "mi_high", "iso_high")
  image_vars <- c("china_imputed", "russia_imputed", "eu_imputed",
                  "israel_imputed", "iran_imputed")

  # Return the label directly; assigning it to an unused local returned it
  # invisibly.
  if (x %in% orientation_vars) {
    "Foreign Policy\nOrientations"
  } else if (x %in% image_vars) {
    "Pre-existing\nImages"
  } else {
    str_to_title(x)
  }
}

# Adjust `label` for marginal contrasts.
#
# When the moderator (x$variable[[1]]) is one of the known orientation or
# country-image variables, `label` is rebuilt as "<tag>: <term>". Any other
# moderator leaves `x` untouched.
#
#   x: coefficient data frame with `variable` and `term` columns.
# Returns `x`, possibly with a new `label` column.
set_diff_label <- function(x) {
  tag <- switch(
    x$variable[[1]],
    "ci_high" = "High CI: ",
    "mi_high" = "High MI: ",
    "iso_high" = "High ISO: ",
    "china_imputed" = "China: ",
    "russia_imputed" = "Russia: ",
    "israel_imputed" = "Israel: ",
    "iran_imputed" = "Iran: ",
    "eu_imputed" = "EU: ",
    NULL
  )

  if (is.null(tag)) {
    x
  } else {
    # `!!` injects the prefix value so a data column can never shadow it.
    mutate(x, label = paste0(!!tag, term))
  }
}

# Visualization of Fitted Model Results ----------------------------------------

# Plot the results of the descriptive analysis as 100% stacked bars of
# response shares per treatment condition.
#
#   var:  name of the response column.
#   by:   name of the condition column (default "interaction").
#   data: data frame; defaults to the global `df`.
# The title reads "Confidence in Attribution" when the response has a
# "Doubtful" level, and "Support for Retaliation" otherwise.
# Returns the ggplot object.
plot_descriptive <- function(var, by = "interaction", data = df) {
  perc_df <- calc_perc(var, by, data)

  plot_title <- if ("Doubtful" %in% perc_df$level) {
    "Confidence in Attribution\nby Treatment Condition"
  } else {
    "Support for Retaliation\nby Treatment Condition"
  }

  # Return the plot itself; it was previously assigned to an unused local,
  # which made the function return invisibly.
  perc_df |>
    ggplot(aes(x = perc,
               y = forcats::fct_reorder2(condition, level == "Doubtful", perc,
                                         .desc = TRUE),
               fill = level)) +
    geom_col(position = position_fill(reverse = TRUE), width = .732) +
    geom_text(aes(label = scales::percent(perc, accuracy = 0.1)),
              colour = "black",
              size = 3.5,
              fontface = "bold",
              position = position_fill(vjust = .5, reverse = TRUE)) +
    labs(x = "Percent",
         y = "Treatment Conditions",
         fill = NULL,
         title = plot_title) +
    theme_bw() +
    theme(legend.position = "right",
          text = element_text(size = 14),
          title = element_text(face = "bold")) +
    scale_x_continuous(labels = scales::percent) +
    scale_fill_grey(start = .9, end = .6, na.value = "grey40")
}

# Calculate percentages for the descriptive analysis.
#
# Cross-tabulates `var` within each level of `by` as row proportions. The
# condition codes (presumably "<disavowal/accusation>.<endorsement>" from
# interaction() — confirm) are replaced by readable labels for the global
# `experiment` ("1", "2", or "3"; anything else is an error).
#
#   var, by: column names in `data`.
#   data:    data frame or tibble.
# Returns a data frame with columns `condition`, `level`, and `perc`
# (proportion within condition), ordered by condition.
calc_perc <- function(var, by, data) {
  condition_replc <- switch(
    as.character(experiment),
    "1" = c("None.Not Endorsed" = "No endorsement\nand No denial",
            "None.Endorsed" = "EU endorsement\nand No denial",
            "Hypocrisy.Not Endorsed" = "No endorsement\nand Hypocrisy denial",
            "Hypocrisy.Endorsed" = "EU endorsement\nand Hypocrisy denial",
            "Evidence.Not Endorsed" = "No endorsement\nand Evidence denial",
            "Evidence.Endorsed" = "EU endorsement\nand Evidence denial"),
    "2" = c("None.Bipartisan" = "Bipartisan endorsement\nand No denial",
            "None.Partisan" = "Partisan endorsement\nand No denial",
            "Evidence.Bipartisan" = "Bipartisan endorsement\nand Evidence denial",
            "Evidence.Partisan" = "Partisan endorsement\nand Evidence denial",
            "Hypocrisy.Bipartisan" = "Bipartisan endorsement\nand Hypocrisy denial",
            "Hypocrisy.Partisan" = "Partisan endorsement\nand Hypocrisy denial"),
    "3" = c("Iran.None" = "No endorsement\nand Iranian attribution",
            "Israel.None" = "No endorsement\nand Israeli attribution",
            "Iran.EU" = "EU endorsement\nand Iranian attribution",
            "Israel.EU" = "EU endorsement\nand Israeli attribution",
            "Iran.US" = "US endorsement\nand Iranian attribution",
            "Israel.US" = "US endorsement\nand Israeli attribution"),
    # Fail clearly instead of passing NULL replacements to recode_factor().
    stop("Invalid experiment. Valid options: 1, 2, 3")
  )

  # `[[` yields a plain vector for both data frames and tibbles
  # (`data[, by]` on a tibble is a one-column tibble, which table() rejects).
  table(data[[by]], data[[var]]) |>
    prop.table(margin = 1) |>
    as.data.frame() |>
    dplyr::rename("condition" = "Var1", "level" = "Var2", "perc" = "Freq") |>
    dplyr::mutate(condition = dplyr::recode_factor(condition, !!!condition_replc)) |>
    dplyr::arrange(condition)
}

# Plot the results of the linear models for Average Treatment Effects (ATE).
#
# Coefficients of the endorsement treatment and of the second treatment are
# shown, dodged across models and faceted by variable. The second treatment
# is "accusation" when the global `experiment` is "3" and "disavowal"
# otherwise.
#
#   model:   list of fitted models (plotted in reverse order).
#   caption: plot title.
#   x_scale: x-axis limits.
# Returns the ggplot object with a grey colour scale.
plot_ate <- function(model, caption = NULL, x_scale = c(-.5, .5)) {
  treatment <- if (experiment == "3") "accusation" else "disavowal"
  treatment_label <- c(accusation = "Accusation", disavowal = "Disavowal")[treatment]

  variable_include <- c("endorsement", treatment)
  variable_labels <- c(endorsement = "Endorsement", treatment_label)

  coef_plot <- ggcoef_compare(
    rev(model),
    type = "dodged",
    tidy_fun = tidy_parameters,
    variable_labels = variable_labels,
    include = variable_include,
    facet_row = "var_label"
  )
  coef_plot <- coef_plot +
    ggtitle(caption) +
    xlab("Coefficients") +
    theme(legend.position = "right",
          text = element_text(size = 14),
          title = element_text(face = "bold")) +
    guides(color = guide_legend(reverse = TRUE)) +
    xlim(x_scale)

  # Messages are silenced while swapping in the grey colour scale
  # (presumably the "scale already present" notice).
  suppressMessages({
    coef_plot + scale_color_grey(start = 0, end = .7)
  })
}

# Plot results for Conditional Average Treatment Effects (CATE).
#
# Draws the pairwise-contrast coefficient data (e.g. from get_diff()) faceted
# by `contrasts`, in grey, with a shared title prefix.
#
#   cate_df: coefficient data accepted by ggcoef_plot().
#   caption: appended to the title "Pairwise Comparison Between Contrasts:".
#   xlim:    x-axis limits.
# Returns the ggplot object.
plot_cate <- function(cate_df, caption = NULL, xlim = c(-.5,.5)) {
  plot_title <- paste0("Pairwise Comparison Between Contrasts:\n", caption)

  # Components are added in this order; adding a list to a ggplot adds each
  # element in turn.
  styling <- list(
    ggtitle(plot_title),
    xlab("Coefficients"),
    scale_color_grey(start = 0, end = .7),
    theme(legend.position = "right",
          text = element_text(size = 14),
          title = element_text(face = "bold")),
    guides(color = guide_legend(reverse = TRUE)),
    # The `xlim` argument is a numeric vector, so function lookup skips it
    # and this still calls the xlim() function (ggplot2's).
    xlim(xlim)
  )

  ggcoef_plot(cate_df, stripped_rows = FALSE, facet_col = "contrasts") + styling
}

# Generate LaTeX tables for robustness checks with modelsummary.
#
#   model:   a model or list of models.
#   type:    "ols" (default; N, R2, adjusted R2) or "log" (exponentiated
#            coefficients; N, AIC, BIC). Any other value is an error.
#   x:       "at" (confidence in attribution) or "rt" (support for
#            retaliation); builds the table title.
#   caption: title used as-is when `x` is NULL. When `x` is given, the title
#            is built from `x` and `caption` is ignored.
# Coefficient labels depend on the global `experiment` ("1", "2", or "3").
# Returns the modelsummary() output with output = "latex".
get_table <- function(model, type = "ols", x = NULL, caption = NULL) {
  exponentiate <- FALSE
  gof_map <- c("nobs", "r.squared", "adj.r.squared")

  if (type == "log") {
    # Exponentiate coefficients (odds ratios for logit / ordered-logit fits)
    # and report likelihood-based fit statistics instead of R2.
    exponentiate <- TRUE
    gof_map <- c("nobs", "aic", "bic")
  }

  coef_map <- switch (experiment,
    "1" = c("endorsementEndorsed" = "EU Endorsement",
            "disavowalEvidence" = "Evidence",
            "disavowalHypocrisy" = "Hypocrisy",
            "disavowalEvidence:endorsementEndorsed" = "Evidence x EU Endorsement",
            "disavowalHypocrisy:endorsementEndorsed" = "Hypocrisy x EU Endorsement"),
    "2" = c("endorsementBipartisan" = "Bipartisan",
            "disavowalEvidence" = "Evidence",
            "disavowalHypocrisy" = "Hypocrisy",
            "disavowalEvidence:endorsementBipartisan" = "Evidence x Bipartisan",
            "disavowalHypocrisy:endorsementBipartisan" = "Hypocrisy x Bipartisan"),
    "3" = c("endorsementEU" = "EU Endorsement",
            "endorsementUS" = "US Endorsement",
            "accusationIsrael" = "Israeli Attribution",
            "endorsementEU:accusationIsrael" = "Israel Attr. x EU End.",
            "endorsementUS:accusationIsrael" = "Israel Attr. x US End."),
    stop("Invalid `experiment` value")
  )

  model_cap <- switch (type,
    "ols" = "Results of OLS Models for",
    "log" = "Result of Logistic and Ordered Logistic Models for",
    stop("Invalid `type` value")
  )

  # Build the title from `x` when given. With the default `x = NULL`,
  # switch() would fail, so fall back to the caller's `caption`.
  if (!is.null(x)) {
    caption <- switch (x,
      "at" = paste0(model_cap, " Confidence in Attribution (Experiment ", experiment, ")"),
      "rt" = paste0(model_cap, " Support for Retaliation (Experiment ", experiment, ")"),
      stop("Invalid `x` value")
    )
  } else if (is.null(caption)) {
    stop("Supply either `x` (\"at\" or \"rt\") or `caption`")
  }

  modelsummary::modelsummary(model, exponentiate = exponentiate, stars = TRUE,
                             coef_map = coef_map, gof_map = gof_map,
                             fmt = \(value) round(value, 2),
                             title = caption, output = "latex")
}